# 機能設計書 74-DTensor API

## 概要

本ドキュメントは、TensorFlowにおけるDTensor API機能の設計について記述する。DTensorは、分散テンソルの作成・配置・レイアウト制御を提供するAPIであり、テンソルのシャーディング（分割配置）とレプリケーション（複製配置）をメッシュトポロジー上で管理する。

### 本機能の処理概要

**業務上の目的・背景**：大規模モデルの分散学習では、データ並列だけでなくモデル並列（テンソルの分割配置）が必要となる。DTensorは、テンソルの各次元をメッシュの各次元にマッピングするLayout（レイアウト）概念を導入し、SPMD（Single Program Multiple Data）ベースの分散計算を実現する。これにより、ユーザは分散通信を明示的に記述することなく、モデル並列・データ並列・パイプライン並列を柔軟に組み合わせた分散戦略を構築できる。

**機能の利用シーン**：大規模言語モデル（LLM）やビジョンモデルの分散学習、テンソルのシャーディング配置指定、デバイスメッシュ上での計算実行、既存tf.functionのDTensor対応化。

**主要な処理内容**：
1. Meshの定義: デバイスのトポロジー（次元名とサイズ）を記述
2. Layoutの定義: テンソル各次元のシャーディングスペックを記述
3. relayout: テンソルのレイアウト変更（シャーディング/レプリケーション切替）
4. pack/unpack: ローカルテンソルとDTensorの相互変換
5. default_mesh: デフォルトメッシュスコープの管理
6. call_with_layout: レイアウト指定でのTF関数呼び出し
7. 勾配の自動レイアウト伝播

**関連システム・外部連携**：XLA SPMD（TPU使用時）、DTensor C++ランタイム、NCCL通信

**権限による制御**：特段の権限制御はない。メッシュに含まれるデバイスへのアクセス権が必要。

## 関連画面

本機能はバックエンドの分散計算基盤であるため、直接的な関連画面は存在しない。

## 機能種別

分散計算API / テンソル配置管理

## 入力仕様

### 入力パラメータ（Mesh）

| パラメータ名 | 型 | 必須 | 説明 | バリデーション |
|-------------|-----|-----|------|---------------|
| dim_names | List[str] | Yes | 次元名のリスト | global_device_idsの次元数と一致すること |
| global_device_ids | np.ndarray | Yes | グローバルデバイスID配列 | 非空、連番であること |
| local_device_ids | List[int] | Yes | ローカルデバイスIDリスト | global_device_idsの部分集合 |
| local_devices | List[DeviceSpec/str] | Yes | ローカルデバイスリスト | local_device_idsと同数 |
| mesh_name | str | No | メッシュ名 | デフォルト空文字列 |
| global_devices | List[DeviceSpec/str] | No | グローバルデバイスリスト | マルチクライアント時に設定 |
| use_xla_spmd | bool | No | XLA SPMDの使用 | TPUのみ有効 |

### 入力パラメータ（Layout）

| パラメータ名 | 型 | 必須 | 説明 | バリデーション |
|-------------|-----|-----|------|---------------|
| sharding_specs | List[str] | Yes | シャーディングスペックのリスト | メッシュ次元名またはUNSHARDED/MATCH |
| mesh | Mesh | Yes | メッシュオブジェクト | Meshインスタンスであること |

### 入力データソース

プログラムコードから直接指定されるテンソルおよびレイアウト情報

## 出力仕様

### 出力データ

| 項目名 | 型 | 説明 |
|--------|-----|------|
| DTensor | Tensor | レイアウト属性を持つ分散テンソル |
| Layout | Layout | テンソルのレイアウト情報 |
| Mesh | Mesh | デバイスメッシュ情報 |
| component tensors | List[Tensor] | unpack時のローカルコンポーネントテンソル |

### 出力先

後続の計算グラフオペレーションへの入力

## 処理フロー

### 処理シーケンス

```
1. メッシュの定義
   └─ Mesh(dim_names, global_device_ids, local_device_ids, local_devices)
2. レイアウトの定義
   └─ Layout(sharding_specs, mesh)
3. DTensorの作成
   ├─ relayout(tensor, layout): 既存テンソルのレイアウト変更
   ├─ pack(tensors, layout): ローカルテンソルからDTensor作成
   └─ call_with_layout(fn, layout): レイアウト指定で関数実行
4. DTensorの使用
   └─ 通常のTF操作を実行（SPMDで自動分散）
5. DTensorの分解
   └─ unpack(dtensor): ローカルコンポーネントテンソルに分解
```

### フローチャート

```mermaid
flowchart TD
    A[Mesh定義] --> B[Layout定義]
    B --> C{DTensor作成方法}
    C -->|relayout| D[既存テンソルのレイアウト変更]
    C -->|pack| E[ローカルテンソルをDTensorに結合]
    C -->|call_with_layout| F[レイアウト指定で関数実行]
    D --> G[SPMD分散計算実行]
    E --> G
    F --> G
    G --> H{結果取得方法}
    H -->|DTensorのまま| I[後続計算に使用]
    H -->|unpack| J[ローカルテンソルに分解]
```

## ビジネスルール

### 業務ルール

| ルールNo | ルール名 | 内容 | 適用条件 |
|---------|---------|------|---------|
| BR-74-01 | メッシュデバイスID連番 | global_device_idsは連番でなければならない | Mesh初期化時 |
| BR-74-02 | シャーディング一意性 | Layout内の各シャーディングスペックはメッシュ次元名として一意であること（UNSHARDED/MATCHを除く） | Layout初期化時 |
| BR-74-03 | デバイス型統一 | メッシュ内の全デバイスは同一デバイスタイプであること | Mesh初期化時 |
| BR-74-04 | XLA SPMD TPU制限 | use_xla_spmd=TrueはTPUメッシュでのみ有効 | Mesh初期化時 |
| BR-74-05 | pack/unpackはEagerのみ | pack()とunpack()はEagerモードでのみ使用可能 | 関数呼び出し時 |
| BR-74-06 | relayoutの方向制限 | replicated→shardedまたはsharded→replicatedのみ対応。"x,y"→"z,y"のようなメッシュ次元変更は非対応 | relayout呼び出し時 |
| BR-74-07 | スカラーのレイアウト | スカラーDTensorは完全レプリケーションのみ有効 | pack時 |

### 計算ロジック

- relayoutの内部実装: Split操作（replicated→sharded）またはAllToAll操作（sharded→replicated）に展開
- メッシュストライド計算: shape=[a,b,c,d]のとき、strides=[b*c*d, c*d, d, 1]

## データベース操作仕様

本機能はデータベース操作を行わない。

## エラー処理

### エラーケース一覧

| エラーコード | エラー種別 | 発生条件 | 対処方法 |
|------------|----------|---------|---------|
| ValueError | メッシュ不正 | global_device_idsが空 | 非空の配列を渡す |
| ValueError | メッシュ不正 | global_device_idsが非連番 | 連番の配列を渡す |
| ValueError | 次元不一致 | dim_namesの数とglobal_device_idsの次元数が異なる | 一致するよう修正 |
| ValueError | デバイス重複 | ローカルデバイスに重複がある | 一意なデバイスを指定 |
| ValueError | レイアウト不正 | sharding_specsにメッシュに存在しない次元名 | メッシュ次元名またはUNSHARDEDを使用 |
| ValueError | レイアウト不正 | sharding_specsで同一メッシュ次元が複数回使用 | 各メッシュ次元は1回のみ使用 |
| ValueError | XLA SPMD不正 | use_xla_spmd=TrueでTPU以外のメッシュ | TPUメッシュでのみ使用 |
| RuntimeError | モード不正 | pack/unpackをGraphモードで呼び出し | Eagerモードで使用 |

### リトライ仕様

自動リトライ機能は提供されない。

## トランザクション仕様

該当なし。

## パフォーマンス要件

- relayoutでsharded→replicatedの変換はAllToAll通信が発生するためコスト高
- replicated→shardedの変換はSplit操作のみで通信不要
- use_xla_spmd=TrueによりTPU上でXLAコンパイラのSPMD最適化を活用可能

## セキュリティ考慮事項

- メッシュに含まれるデバイスへのアクセス制御は、TensorFlowのデバイスアクセス設定に依存

## 備考

- DTensorデバイスはシングルトンとして管理される（_dtensor_singleton）
- スレッドロックによる初期化の排他制御
- 勾配計算のrelayout/relayout_likeは自動登録される（@ops.RegisterGradient）
- MATCHシャーディングスペックはrelayoutでのアイデンティティ操作を指定

---

## コードリーディングガイド

本機能を理解するために参照すべきファイルと、推奨する読み解き順序を以下に示す。

### 推奨読解順序

#### Step 1: データ構造を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 1-1 | layout.py | `tensorflow/dtensor/python/layout.py` | Meshクラス（53-345行目）のデバイスメッシュ定義を理解する |
| 1-2 | layout.py | `tensorflow/dtensor/python/layout.py` | Layoutクラス（351-551行目）のシャーディングスペック定義を理解する |

**読解のコツ**:
- Meshは`_pywrap_dtensor_device.Mesh`のPythonラッパー。C++実装をpybind11で公開。
- **32-34行目**: UNSHARDED = 'unsharded'、MATCH = 'match'の定数定義
- **42行目**: MeshDimensionはnamedtupleで(name, size)のペア
- **45-50行目**: ストライド計算関数_compute_mesh_strides
- **75-222行目**: Mesh.__init__でデバイスID、デバイス名、次元名の検証
- **128-139行目**: global_device_idsの連番検証
- **188-191行目**: ローカルデバイスの重複検証
- **200-211行目**: デバイスタイプの統一検証（CPU/GPU/TPU単一型のみ）
- **397-434行目**: Layout.__init__でシャーディングスペックの一意性検証

#### Step 2: エントリーポイントを理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 2-1 | api.py | `tensorflow/dtensor/python/api.py` | DTensor APIの主要関数を理解する |

**主要処理フロー**:
1. **37-64行目**: call_with_layout - レイアウト指定での関数実行（Eagerモードではdefault_mesh+_default_layout、Graphモードではrelayoutで対応）
2. **89-112行目**: default_mesh - デフォルトメッシュスコープの設定（contextmanager）
3. **147-160行目**: is_dtensor - テンソルがDTensorかどうかの判定
4. **167-188行目**: copy_to_mesh - テンソルをDTensorデバイスにコピー（内部的にrelayout呼出し）
5. **191-339行目**: pack - ローカルテンソルをDTensorに結合
6. **342-372行目**: unpack - DTensorをローカルコンポーネントに分解
7. **411-449行目**: relayout - テンソルのレイアウト変更（gen_dtensor_ops.relayoutを呼出し）
8. **452-502行目**: relayout_like - 他のテンソルのレイアウトに合わせる

#### Step 3: 勾配の自動レイアウト伝播を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 3-1 | api.py | `tensorflow/dtensor/python/api.py` | 勾配登録関数（538-568行目） |

**読解のコツ**: `@ops.RegisterGradient("Relayout")`でRelayout Opの勾配がrelayout_likeとして登録される。これにより、逆伝播時に勾配テンソルが入力テンソルと同じレイアウトに自動配置される。

#### Step 4: シングルトンデバイス管理を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 4-1 | api.py | `tensorflow/dtensor/python/api.py` | _dtensor_singleton（30-31行目）と_dtensor_device()関数（519-523行目） |

### プログラム呼び出し階層図

```
DTensor API (api.py)
    │
    ├─ relayout(tensor, layout)
    │      └─ gen_dtensor_ops.relayout(tensor, layout_str)
    │             └─ Relayout C++ Op → SPMD展開 → Split / AllToAll
    │
    ├─ relayout_like(tensor, layout_tensor)
    │      └─ gen_dtensor_ops.relayout_like(input, layout_input)
    │             └─ RelayoutLike C++ Op
    │
    ├─ pack(tensors, layout)
    │      └─ _dtensor_device().pack(tensors, layout)
    │             └─ DTensorDevice C++ 実装
    │
    ├─ unpack(tensor)
    │      └─ _dtensor_device().unpack(tensor)
    │             └─ DTensorDevice C++ 実装
    │
    ├─ default_mesh(mesh)
    │      └─ _dtensor_device()._experimental_default_mesh(mesh)
    │
    └─ call_with_layout(fn, layout)
           ├─ [Eager] default_mesh → _default_layout → fn()
           └─ [Graph] fn() → relayout()
```

### データフロー図

```
[入力]                       [処理]                           [出力]

Mesh定義 ──────────▶   Layout定義                         DTensor
  dim_names              sharding_specs ──▶ relayout() ──▶  (分散テンソル)
  device_ids             mesh

ローカルテンソル ─────▶  pack(tensors, layout) ─────────▶  DTensor

DTensor ──────────▶  unpack(dtensor) ──────────────────▶  ローカルテンソル群
```

### 関連ファイル一覧

| ファイル | パス | 種別 | 役割 |
|---------|------|------|------|
| api.py | `tensorflow/dtensor/python/api.py` | ソース | DTensor主要APIエントリーポイント |
| layout.py | `tensorflow/dtensor/python/layout.py` | ソース | Mesh/Layoutデータ構造定義 |
| dtensor_device.py | `tensorflow/dtensor/python/dtensor_device.py` | ソース | DTensorデバイスのPythonラッパー |
| gen_dtensor_ops | 自動生成 | ソース | C++ Op呼び出しブリッジ |
| _pywrap_dtensor_device | C++バインディング | ソース | Mesh/Layoutの基底C++実装 |
| layout.proto | `tensorflow/dtensor/proto/layout.proto` | Proto | レイアウト/メッシュのProtobuf定義 |
